import os
import time
import argparse
import torch
import matplotlib.pyplot as plt

from PIL import Image
from model.model import Yolov2
from model.utils.weight_loader import load_weight
from model.utils.eval import prepare_im_data, yolo_eval
from model.utils.visualize import draw_detection_boxes
from model.utils.names import CLASSES

# Detector head configuration passed to the Yolov2 constructor in main().
# NOTE(review): NUM_CLASSES presumably has to match len(CLASSES) from
# model.utils.names and the trained checkpoint — confirm before changing.
NUM_CLASSES, NUM_ANCHORS = 20, 5


def _load_checkpoint(model, model_path, model_type, device):
    """Load weights of format ``model_type`` from ``model_path`` into ``model``.

    Args:
        model: The Yolov2 network to populate (modified in place).
        model_path: Absolute path to the weights file.
        model_type: ``"pytorch"`` (checkpoint dict with a ``"model"`` state dict)
            or ``"darknet"`` (handled by ``load_weight``).
        device: ``torch.device`` that tensors of a PyTorch checkpoint are mapped to.

    Raises:
        RuntimeError: "Model type error" for an unknown ``model_type``, or
            "No model file" when ``model_path`` does not exist.
    """
    # Validate the format first so an unknown type is reported as such even
    # when the file is also missing (same precedence as before).
    if model_type not in ("pytorch", "darknet"):
        raise RuntimeError("Model type error")
    # Check existence for both formats, not just PyTorch, so a missing darknet
    # file gets the same clear error instead of one from deep inside the loader.
    if not os.path.exists(model_path):
        raise RuntimeError("No model file")

    if model_type == "pytorch":
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
        # host (or on a different GPU index) instead of failing in torch.load.
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint["model"])
    else:
        load_weight(model, model_path)


def _synchronize(device):
    """Wait for queued CUDA work on ``device`` to finish; no-op on other devices.

    CUDA kernels are launched asynchronously. Without a sync, a wall-clock
    timer would mostly measure launch overhead, not actual inference time.
    """
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def main(args):
    """Run YOLOv2 detection on one image, report timing and show the boxes.

    Args:
        args: Parsed CLI namespace (see ``parse_args``) with ``img`` (image
            path), ``device`` (torch device string), ``model`` (weights path,
            made absolute relative to the current directory if needed) and
            ``type`` (``"pytorch"`` or ``"darknet"``).

    Raises:
        RuntimeError: If the weights file is missing or the model type is unknown.
    """
    device = torch.device(args.device)
    model = Yolov2(num_classes=NUM_CLASSES, num_anchors=NUM_ANCHORS).to(device)

    model_path = args.model
    if not os.path.isabs(model_path):
        model_path = os.path.join(os.getcwd(), model_path)
    _load_checkpoint(model, model_path, args.type, device)
    # Switch to inference mode once, outside the timed region.
    model.eval()
    print("model loaded!")

    # Image.open is lazy and keeps the file open. convert() loads the pixels into
    # a new image, and the context manager then closes the file. Forcing RGB
    # keeps grayscale or RGBA inputs from reaching a 3-channel network.
    with Image.open(args.img) as raw_img:
        img = raw_img.convert("RGB")
    im_data, im_info = prepare_im_data(img)
    im_data = im_data.to(device)

    # Pure inference: no autograd graph is needed, which saves memory and time.
    with torch.no_grad():
        _synchronize(device)  # exclude the pending host->device copy from timing
        tic = time.perf_counter()
        yolo_output = model(im_data)
        detections = yolo_eval(yolo_output, im_info, conf_threshold=0.6, nms_threshold=0.4)
        _synchronize(device)
        cost_time = time.perf_counter() - tic

    fps = 1.0 / cost_time if cost_time > 0 else float("inf")
    print("im detect, cost time {:.4f}, FPS: {:.1f}".format(cost_time, fps))

    if len(detections) > 0:
        det_boxes = detections[:, :5].cpu().numpy()
        det_classes = detections[:, -1].long().cpu().numpy()
        im2show = draw_detection_boxes(img, det_boxes, det_classes, class_names=CLASSES)
        plt.figure()
        plt.imshow(im2show)
        plt.show()
    else:
        print("No objects detected!")


def parse_args(argv=None):
    """Parse command-line options for single-image YOLOv2 inference.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``img``, ``device``, ``model`` and ``type``.
    """
    parser = argparse.ArgumentParser(description="Yolo(v2) Inference")

    # fmt: off
    # main() opens exactly one file via PIL, so the help must not advertise
    # directory or URL inputs.
    parser.add_argument("img", help="Path to a single image file")
    parser.add_argument("--device", default='cuda:0', type=str, help="Device used for inference")
    parser.add_argument('--model', default="output/yolov2_epoch_160.pth", type=str,
                        help="Weights file; relative paths are resolved against the current directory")
    # choices rejects typos at parse time with a usage message instead of a
    # RuntimeError after the network has already been built.
    parser.add_argument('--type', default="pytorch", type=str, choices=["pytorch", "darknet"],
                        help="Weights format: PyTorch checkpoint or Darknet weights")
    # fmt: on

    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: parse the CLI options and hand them straight to main().
    main(parse_args())
